
Conversation

@Angazenn (Contributor) commented May 12, 2025

What this PR does / why we need it?

  1. This PR introduces a native `all_to_all` communication operator to fix `all_gather` bugs when `dp_size > 1`. It also adds a naive force-load-balance implementation for profile runs.
  2. The operator `npu_dequant_swiglu_quant` only supports input `hidden_states` with dtype `torch.int32`. This tensor occupies `global_bs * seq_len * topk * hidden_size` elements, which can grow very large as `ep_size` increases, so we need to disable this operator and fall back to the original swiglu and quantize ops.
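The dispatch pattern behind point 1 can be illustrated with a small, self-contained simulation (pure Python, no process group; names like `scatter_sizes`/`gather_sizes` mirror the PR's variables, but the helpers are illustrative only). Each rank first exchanges per-destination token counts, then exchanges the tokens themselves using those counts as split sizes:

```python
# Pure-Python simulation of the two-phase all_to_all the PR relies on.
# Phase 1: exchange split sizes, so rank r's gather_sizes[i] equals
# rank i's scatter_sizes[r]. Phase 2: exchange token chunks accordingly.

def all_to_all_sizes(scatter_sizes_per_rank):
    """Simulate the size exchange (dist.all_to_all_single on size tensors)."""
    world = len(scatter_sizes_per_rank)
    return [[scatter_sizes_per_rank[i][r] for i in range(world)]
            for r in range(world)]

def all_to_all_tokens(tokens_per_rank, scatter_sizes_per_rank):
    """Simulate the token exchange driven by the exchanged split sizes."""
    world = len(tokens_per_rank)
    # Split each rank's send buffer according to its scatter sizes.
    chunks = []
    for r in range(world):
        out, start = [], 0
        for size in scatter_sizes_per_rank[r]:
            out.append(tokens_per_rank[r][start:start + size])
            start += size
        chunks.append(out)
    # Rank r receives chunk r from every rank, concatenated in rank order.
    return [sum((chunks[i][r] for i in range(world)), [])
            for r in range(world)]

scatter = [[2, 1], [0, 3]]            # rank 0: 2 tokens to rank 0, 1 to rank 1
tokens = [[10, 11, 12], [20, 21, 22]]
gather = all_to_all_sizes(scatter)    # [[2, 0], [1, 3]]
recv = all_to_all_tokens(tokens, scatter)
# recv == [[10, 11], [12, 20, 21, 22]]
```

In the real operator the same roles are played by `torch.distributed.all_to_all_single` with `input_split_sizes`/`output_split_sizes`; the simulation only shows the bookkeeping.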

Does this PR introduce any user-facing change?

No.

How was this patch tested?

By performing offline inference:
[screenshot of offline inference output]

angazenn added 2 commits May 12, 2025 21:42
Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
@Angazenn Angazenn force-pushed the all2all branch 3 times, most recently from af2215a to d6cf7a1 Compare May 13, 2025 02:23
@Angazenn Angazenn changed the title [WIP]add all2all when dp_size > 1 [WIP]add all2all when dp_size > 1 && downgrade npu_dequant_swiglu_quant May 13, 2025
Signed-off-by: angazenn <zengyanjia@huawei.com>
@Angazenn Angazenn force-pushed the all2all branch 2 times, most recently from b2ca756 to e0ab8e0 Compare May 13, 2025 03:07
Signed-off-by: angazenn <zengyanjia@huawei.com>
angazenn added 4 commits May 13, 2025 14:18
Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
```python
# Exchange per-rank split sizes so each rank knows how much it will receive.
dist.all_to_all_single(gather_sizes,
                       scatter_sizes,
                       group=ep_group.device_group)
scatter_size_list = scatter_sizes.cpu().tolist()
```
Collaborator: This may introduce a serious performance regression; please note that we will change this in the future.

Contributor (author): done

```python
                    gather_dim, scatter_sizes,
                    gather_sizes)

    def reduce_scatter(self,
```
Collaborator: Remove this if you do not use it.

Contributor (author): done

```python
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is None:
    # During profile runs, force expert load balance to avoid high memory
    # consumption on a single rank.
```
Collaborator: Add more comments on this.

Contributor (author): done
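The force-load-balance path discussed above can be sketched as follows. This is a hypothetical helper, not the PR's actual code: during a profile run (when `attn_metadata is None`), the router's output is replaced with a uniform round-robin expert assignment so that no single rank's experts receive a disproportionate share of tokens:

```python
def balanced_topk_ids(num_tokens, num_experts, topk):
    """Hypothetical stand-in for router output during a profile run:
    assign experts round-robin so load is uniform across all experts."""
    ids, flat = [], 0
    for _ in range(num_tokens):
        row = []
        for _ in range(topk):
            row.append(flat % num_experts)
            flat += 1
        ids.append(row)
    return ids

# With 4 tokens, 8 experts, topk=2, every expert gets exactly one slot:
ids = balanced_topk_ids(num_tokens=4, num_experts=8, topk=2)
# ids == [[0, 1], [2, 3], [4, 5], [6, 7]]
```

The real implementation operates on NPU tensors, but the effect is the same: memory consumption during profiling reflects the balanced worst case rather than whatever skewed routing the untrained-on-this-input gate happens to produce.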

angazenn added 2 commits May 13, 2025 22:37
Signed-off-by: angazenn <zengyanjia@huawei.com>
Signed-off-by: angazenn <zengyanjia@huawei.com>
@Angazenn Angazenn force-pushed the all2all branch 2 times, most recently from 3349fbd to 30eafb9 Compare May 13, 2025 15:04
Signed-off-by: angazenn <zengyanjia@huawei.com>
@Angazenn Angazenn changed the title [WIP]add all2all when dp_size > 1 && downgrade npu_dequant_swiglu_quant [BugFix]add all2all when dp_size > 1 && downgrade npu_dequant_swiglu_quant May 14, 2025
@ganyi1996ppo ganyi1996ppo merged commit 1e67089 into vllm-project:main May 15, 2025
15 checks passed
yiz-liu added a commit to yiz-liu/vllm-ascend that referenced this pull request May 17, 2025
@ttanzhiqiang (Contributor): This is the best solution for A3 performance. Is there a best solution for A2 performance?

ganyi1996ppo pushed a commit that referenced this pull request May 24, 2025
### What this PR does / why we need it?
This PR fixes two accuracy bugs introduced by PR #819 when running
deepseekv3 series models:
1. #819 adds `all_to_all` communication in quantized cases, but
`all_gather` and `reduce_scatter` were removed in both the quantized and
unquantized cases. When running unquantized deepseekv3 models with
`ep_size == world_size`, the MoE modules fail to communicate. This PR
therefore adds `all_to_all` communication in the unquantized case as well
to fix the accuracy issue.
2. Use `ep_size` rather than `dp_size` to decide whether to use
`all_to_all` in MoE.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed with new added/existing test.

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
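The second fix in the commit above (gating on expert parallelism rather than data parallelism) amounts to a condition like the following; the helper name and the `> 1` threshold are illustrative assumptions, not the PR's exact code:

```python
def use_all_to_all(ep_size: int) -> bool:
    """Gate MoE dispatch on expert parallelism, not data parallelism:
    all_to_all is only meaningful when experts are sharded across ranks,
    which is what ep_size (not dp_size) measures."""
    return ep_size > 1

# dp_size may be > 1 while all experts live on one rank (ep_size == 1);
# in that case no inter-rank token exchange is needed.
```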
momo609 pushed a commit to momo609/vllm-ascend that referenced this pull request May 30, 2025
(commit message identical to the #897 description above)
Signed-off-by: wangxiaoxin (A) <w00664509@china.huawei.com>
zxdukki pushed a commit to zxdukki/vllm-ascend that referenced this pull request Jun 3, 2025
(commit message identical to the #897 description above)
@Angazenn Angazenn deleted the all2all branch September 8, 2025 03:16